// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

package middleware

import (
	"net/http"
	"sync"
	"time"

	"gitee.com/zhangycsz/glin"
	"golang.org/x/time/rate"
)

// RateLimiterStore is the interface to be implemented by custom stores.
type RateLimiterStore interface {
	// Allow reports whether the request identified by identifier may proceed.
	// When the first result is false, the rate limiter middleware passes the
	// returned error on to the configured DenyHandler; when it is true, the
	// middleware does not inspect the error.
	Allow(identifier string) (bool, error)
}

// RateLimiterConfig defines the configuration for the rate limiter.
// RateLimiterWithConfig replaces a nil Skipper, IdentifierExtractor, ErrorHandler
// or DenyHandler with the value from DefaultRateLimiterConfig, and panics when
// Store is nil.
type RateLimiterConfig struct {
	// Skipper, when it returns true for a request, makes the middleware call the
	// next handler without consulting the store. BeforeFunc, when non-nil, is
	// invoked after the skip check and before the identifier is extracted.
	Skipper    Skipper
	BeforeFunc BeforeFunc
	// IdentifierExtractor uses glin.Context to extract the identifier for a visitor
	IdentifierExtractor Extractor
	// Store defines a store for the rate limiter (required)
	Store RateLimiterStore
	// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
	ErrorHandler func(context glin.Context, err error) error
	// DenyHandler provides a handler to be called when RateLimiter denies access.
	// It receives the extracted identifier and the error returned by Store.Allow.
	DenyHandler func(context glin.Context, identifier string, err error) error
}

// Extractor is used to extract data from glin.Context.
// It returns the extracted string, or a non-nil error when extraction fails.
type Extractor func(context glin.Context) (string, error)

// ErrRateLimitExceeded denotes an error raised when the rate limit is exceeded.
// It carries http.StatusTooManyRequests; the default DenyHandler reuses its
// Code and Message.
var ErrRateLimitExceeded = glin.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")

// ErrExtractorError denotes an error raised when the extractor function is unsuccessful.
// It carries http.StatusForbidden; the default ErrorHandler reuses its Code and
// Message.
var ErrExtractorError = glin.NewHTTPError(http.StatusForbidden, "error while extracting identifier")

// DefaultRateLimiterConfig defines default values for RateLimiterConfig.
// It has no Store: callers must always supply one.
var DefaultRateLimiterConfig = RateLimiterConfig{
	Skipper: DefaultSkipper,
	// Visitors are identified by the address reported by Context.RealIP.
	IdentifierExtractor: func(ctx glin.Context) (string, error) {
		return ctx.RealIP(), nil
	},
	// Extraction failures are reported with ErrExtractorError's code and message,
	// keeping the original failure as the internal error.
	ErrorHandler: func(_ glin.Context, err error) error {
		httpErr := &glin.HTTPError{Internal: err}
		httpErr.Code = ErrExtractorError.Code
		httpErr.Message = ErrExtractorError.Message
		return httpErr
	},
	// Denied requests are reported with ErrRateLimitExceeded's code and message,
	// keeping the store's error (if any) as the internal error.
	DenyHandler: func(_ glin.Context, _ string, err error) error {
		httpErr := &glin.HTTPError{Internal: err}
		httpErr.Code = ErrRateLimitExceeded.Code
		httpErr.Message = ErrRateLimitExceeded.Message
		return httpErr
	},
}

/*
RateLimiter returns a rate limiting middleware that uses the given store and
DefaultRateLimiterConfig for every other setting.

	e := glin.New()

	limiterStore := middleware.NewRateLimiterMemoryStore(20)

	e.GET("/rate-limited", func(c glin.Context) error {
		return c.String(http.StatusOK, "test")
	}, middleware.RateLimiter(limiterStore))
*/
func RateLimiter(store RateLimiterStore) glin.MiddlewareFunc {
	// Work on a copy so the package-level default is never modified.
	cfg := DefaultRateLimiterConfig
	cfg.Store = store
	return RateLimiterWithConfig(cfg)
}

/*
RateLimiterWithConfig returns a rate limiting middleware built from config.

Nil Skipper, IdentifierExtractor, ErrorHandler and DenyHandler fields are replaced
with the values from DefaultRateLimiterConfig. It panics when config.Store is nil.

For every request that is not skipped, the middleware runs BeforeFunc (if set),
extracts the visitor identifier and asks the store whether the request is allowed.
When extraction fails or the store denies the request, the result of the matching
handler is passed to Context.Error and the middleware itself returns nil.

	e := glin.New()

	config := middleware.RateLimiterConfig{
		Skipper: middleware.DefaultSkipper,
		Store: middleware.NewRateLimiterMemoryStoreWithConfig(
			middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute},
		),
		IdentifierExtractor: func(ctx glin.Context) (string, error) {
			id := ctx.RealIP()
			return id, nil
		},
		ErrorHandler: func(context glin.Context, err error) error {
			return context.JSON(http.StatusForbidden, nil)
		},
		DenyHandler: func(context glin.Context, identifier string, err error) error {
			return context.JSON(http.StatusTooManyRequests, nil)
		},
	}

	e.GET("/rate-limited", func(c glin.Context) error {
		return c.String(http.StatusOK, "test")
	}, middleware.RateLimiterWithConfig(config))
*/
func RateLimiterWithConfig(config RateLimiterConfig) glin.MiddlewareFunc {
	// A store cannot be defaulted, so its absence is a configuration error.
	if config.Store == nil {
		panic("Store configuration must be provided")
	}

	// Fill every optional hook that was left unset.
	defaults := DefaultRateLimiterConfig
	if config.Skipper == nil {
		config.Skipper = defaults.Skipper
	}
	if config.IdentifierExtractor == nil {
		config.IdentifierExtractor = defaults.IdentifierExtractor
	}
	if config.ErrorHandler == nil {
		config.ErrorHandler = defaults.ErrorHandler
	}
	if config.DenyHandler == nil {
		config.DenyHandler = defaults.DenyHandler
	}

	return func(next glin.HandlerFunc) glin.HandlerFunc {
		return func(c glin.Context) error {
			if config.Skipper(c) {
				return next(c)
			}
			if config.BeforeFunc != nil {
				config.BeforeFunc(c)
			}

			identifier, extractErr := config.IdentifierExtractor(c)
			if extractErr != nil {
				// The handler's result goes through Context.Error; nothing is
				// returned to the caller of this middleware.
				c.Error(config.ErrorHandler(c, extractErr))
				return nil
			}

			allowed, storeErr := config.Store.Allow(identifier)
			if allowed {
				return next(c)
			}
			c.Error(config.DenyHandler(c, identifier, storeErr))
			return nil
		}
	}
}

// RateLimiterMemoryStore is the built-in store implementation for RateLimiter.
// It keeps one Visitor (and thus one rate.Limiter) per identifier in a map.
type RateLimiterMemoryStore struct {
	// visitors maps an identifier to its limiter state. Allow holds mutex while
	// it reads or modifies visitors, the visitors' lastSeen and lastCleanup.
	// rate and burst are the arguments used to create each new visitor's limiter.
	visitors map[string]*Visitor
	mutex    sync.Mutex
	rate     rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.

	// expiresIn is both the idle time after which a visitor is considered stale
	// and the minimum interval between two cleanups; lastCleanup records when
	// the last cleanup finished.
	burst       int
	expiresIn   time.Duration
	lastCleanup time.Time

	// timeNow is the clock used by the store; the constructor sets it to time.Now.
	timeNow func() time.Time
}

// Visitor signifies a unique user's limiter details.
// It embeds the user's rate.Limiter and records, in lastSeen, the time of the
// user's most recent call to RateLimiterMemoryStore.Allow.
type Visitor struct {
	*rate.Limiter
	lastSeen time.Time
}

/*
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with
the provided rate (as req/s).
For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.

Burst and ExpiresIn are left unset, so NewRateLimiterMemoryStoreWithConfig fills
them in: Burst is derived from the rate and ExpiresIn takes its default value.

Example (with 20 requests/sec):

	limiterStore := middleware.NewRateLimiterMemoryStore(20)
*/
func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) {
	cfg := RateLimiterMemoryStoreConfig{Rate: rate}
	return NewRateLimiterMemoryStoreWithConfig(cfg)
}

/*
NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore
with the provided configuration. Rate must be provided.

Zero values are replaced as follows:
  - ExpiresIn falls back to DefaultRateLimiterMemoryStoreConfig.ExpiresIn.
  - Burst falls back to the rounded down value of the configured rate. For a
    positive rate below 1 req/s that would be 0, and a limiter with a burst of 0
    rejects every request, so the burst is raised to 1 in that case. A rate of 0
    still yields a burst of 0, i.e. a store that denies everything.

The built-in memory store is usually capable for modest loads. For higher loads other
store implementations should be considered.

Characteristics:
* Concurrency above 100 parallel requests may cause measurable lock contention
* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map
* A high number of requests from a single IP address may cause lock contention

Example:

	limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig(
		middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute},
	)
*/
func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) {
	store = &RateLimiterMemoryStore{
		visitors:  make(map[string]*Visitor),
		rate:      config.Rate,
		burst:     config.Burst,
		expiresIn: config.ExpiresIn,
		timeNow:   time.Now,
	}

	if store.expiresIn == 0 {
		store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn
	}

	if store.burst == 0 {
		store.burst = int(config.Rate)
		// Truncating a fractional rate such as 0.5 gives 0, which would make the
		// limiter reject every request; allow at least one token instead.
		if store.burst == 0 && config.Rate > 0 {
			store.burst = 1
		}
	}

	// Start the cleanup interval at construction time.
	store.lastCleanup = store.timeNow()
	return store
}

// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore.
// Zero values of Burst and ExpiresIn are replaced by NewRateLimiterMemoryStoreWithConfig.
type RateLimiterMemoryStoreConfig struct {
	Rate      rate.Limit    // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
	Burst     int           // Burst is the maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when the rate limit is reached. When 0, it is derived from Rate.
	ExpiresIn time.Duration // ExpiresIn is the duration after which an unused rate limiter is cleaned up. When 0, DefaultRateLimiterMemoryStoreConfig.ExpiresIn is used.
}

// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore.
// Only ExpiresIn is set; NewRateLimiterMemoryStoreWithConfig uses it when the
// supplied configuration leaves ExpiresIn at zero.
var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{
	ExpiresIn: 3 * time.Minute,
}

// Allow implements RateLimiterStore.Allow.
//
// It looks up (or creates) the limiter for identifier, records the visit time,
// triggers a cleanup of stale visitors when more than expiresIn has passed since
// the previous cleanup, and finally asks the limiter for one token. The returned
// error is always nil.
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
	store.mutex.Lock()

	visitor, known := store.visitors[identifier]
	if !known {
		// First request from this identifier: give it its own limiter.
		visitor = &Visitor{Limiter: rate.NewLimiter(store.rate, store.burst)}
		store.visitors[identifier] = visitor
	}

	seenAt := store.timeNow()
	visitor.lastSeen = seenAt

	if sinceCleanup := seenAt.Sub(store.lastCleanup); sinceCleanup > store.expiresIn {
		store.cleanupStaleVisitors()
	}

	store.mutex.Unlock()

	// The token is requested after the store mutex has been released.
	allowed := visitor.AllowN(store.timeNow(), 1)
	return allowed, nil
}

/*
cleanupStaleVisitors keeps the visitors map from growing without bound: every
visitor whose last visit lies more than expiresIn in the past is removed, and
lastCleanup is then set to the current time.

It does not take the store mutex itself; Allow calls it while holding the mutex.
*/
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() {
	for key, entry := range store.visitors {
		idle := store.timeNow().Sub(entry.lastSeen)
		if idle <= store.expiresIn {
			continue
		}
		delete(store.visitors, key)
	}
	store.lastCleanup = store.timeNow()
}
